#!/usr/bin/env python
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT


import os
import json
import argparse
import itertools
import multiprocessing
import concurrent.futures
from pathlib import Path
import logging
import importlib.util


def _import_validation_utils():
    """Load ``commons/gemm_validation_utils.py`` as a module and return it.

    The file is located relative to this script: it is expected in a
    ``commons`` directory that sits next to this script's own directory.
    The module is executed from its file path under the name
    ``validation_utils``; it is not registered in ``sys.modules`` here.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    module_path = os.path.join(
        os.path.dirname(script_dir), "commons", "gemm_validation_utils.py"
    )

    # Build a spec from the explicit file location, then execute the module.
    module_spec = importlib.util.spec_from_file_location(
        "validation_utils", module_path
    )
    loaded_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded_module)
    return loaded_module


# Import validation functions.
# The helper module is loaded once, at import time of this script, and the
# four functions used below are bound to module-level names so the rest of
# the file can call them directly.
_validation_utils = _import_validation_utils()
is_tile_config_valid = _validation_utils.is_tile_config_valid
is_trait_combination_valid = _validation_utils.is_trait_combination_valid
get_dtype_string = _validation_utils.get_dtype_string
get_abcd_layouts = _validation_utils.get_abcd_layouts

# Request INFO-level logging on the root logger (runs at import time).
logging.basicConfig(level=logging.INFO)


class GemmMultiDKernelBuilder:
    """Generate ck_tile GEMM Multi-D kernel instances from a JSON configuration.

    The builder enumerates tile configurations and trait combinations from
    the loaded configuration, filters them through the validation helpers,
    and then either lists the resulting kernels for CMake
    (``write_kernel_list``) or writes one C++ header per kernel plus a CMake
    include file (``generate_individual``).
    """

    # Order matters: it is the positional order expected by
    # ``_validate_tile_config`` and the nesting order of the enumeration
    # (the last parameter varies fastest).
    _TILE_PARAMS = (
        "tile_m",
        "tile_n",
        "tile_k",
        "warp_m",
        "warp_n",
        "warp_k",
        "warp_tile_m",
        "warp_tile_n",
        "warp_tile_k",
    )

    # Positional order of the entries of a trait combination tuple.
    _TRAIT_PARAMS = (
        "pipeline",
        "epilogue",
        "scheduler",
        "pad_m",
        "pad_n",
        "pad_k",
        "persistent",
    )

    # Map pipeline names to the correct pipeline implementation
    _PIPELINE_IMPLS = {
        "mem": "ck_tile::GemmPipelineAgBgCrMem",
        "compv3": "ck_tile::GemmPipelineAgBgCrCompV3",
        "compv4": "ck_tile::GemmPipelineAgBgCrCompV4",
    }

    # Map pipeline names to base pipeline for hot loop detection
    _BASE_PIPELINES = {
        "mem": "ck_tile::BaseGemmPipelineAgBgCrMem",
        "compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3",
        "compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4",
    }

    # Map scheduler names to the correct enum values
    _SCHEDULER_TYPES = {
        "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
        "interwave": "ck_tile::GemmPipelineScheduler::Interwave",
        "default": "ck_tile::GemmPipelineScheduler::Default",
    }

    def __init__(
        self,
        working_path,
        gpu_target,
        datatype,
        layout,
        elementwise_function,
        config_json=None,
    ):
        """Store the build parameters and load the optional JSON configuration.

        Args:
            working_path: Output directory; created if it does not exist.
            gpu_target: GPU target string forwarded to tile validation.
            datatype: Data type name used for A, B and the D tensors.
            layout: Layout string forwarded to ``get_abcd_layouts``.
            elementwise_function: Name placed after
                ``ck_tile::element_wise::`` in the generated code.
            config_json: Optional path to the JSON configuration file.
        """
        self.working_path = Path(working_path)
        self.gpu_target = gpu_target
        self.datatype = datatype
        self.layout = layout
        self.elementwise_function = elementwise_function
        self.config_json = config_json

        # Create working directory if it doesn't exist
        self.working_path.mkdir(parents=True, exist_ok=True)

        # Always define ``config`` so that later lookups fail with a clear
        # message instead of AttributeError when no file was loaded.
        self.config = {}
        if config_json:
            if os.path.exists(config_json):
                with open(config_json, "r") as f:
                    self.config = json.load(f)
            else:
                logging.warning("Configuration file not found: %s", config_json)

    # ------------------------------------------------------------------
    # Small formatting / lookup helpers
    # ------------------------------------------------------------------
    def _config_section(self, section):
        """Return a top-level section of the loaded configuration.

        Raises:
            KeyError: If the section is absent, e.g. because no
                configuration file was loaded.
        """
        try:
            return self.config[section]
        except KeyError:
            raise KeyError(
                f"Configuration section '{section}' is missing "
                f"(config_json={self.config_json!r})"
            ) from None

    @staticmethod
    def _format_tile_str(tile_config):
        """Format a tile config as ``MxNxK_WMxWNxWK_WTMxWTNxWTK``."""
        return (
            f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
            f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
            f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
        )

    @staticmethod
    def _format_trait_str(trait_combo):
        """Join all entries of a trait combination with underscores."""
        return "_".join(str(x) for x in trait_combo)

    @staticmethod
    def _cpp_bool(value):
        """Render a trait flag (``True`` or the string "true") as a C++ bool literal."""
        return "true" if value in [True, "true"] else "false"

    def _make_kernel_name(self, tile_config, trait_combo):
        """Build the full kernel name, with capitalized boolean flags."""
        pipeline, epilogue, scheduler = trait_combo[:3]
        flags = "_".join(str(flag).capitalize() for flag in trait_combo[3:])
        return (
            f"gemm_multi_d_{self.datatype}_{self.layout}_"
            f"{pipeline}_{epilogue}_{scheduler}_{flags}_"
            f"{self._format_tile_str(tile_config)}"
        )

    # ------------------------------------------------------------------
    # Kernel listing
    # ------------------------------------------------------------------
    def write_kernel_list(self):
        """Write kernel list to file for CMake to read (with comprehensive validation).

        Two files are written into the working directory:
        ``gemm_multi_d_kernel_count.txt`` (the number of kernels) and
        ``gemm_multi_d_kernel_list.txt`` (one
        ``kernel_name|tile_config|trait_combo`` line per kernel).
        """
        # Get configurations using comprehensive validation
        tile_configs = self._get_tile_configs(fast_mode=False)
        trait_combos = self._generate_trait_combinations()

        lines = []
        for tile_config in tile_configs:
            tile_str = self._format_tile_str(tile_config)
            for trait_combo in trait_combos:
                kernel_name = self._make_kernel_name(tile_config, trait_combo)
                trait_str = self._format_trait_str(trait_combo)
                # Format: kernel_name|tile_config|trait_combo
                lines.append(f"{kernel_name}|{tile_str}|{trait_str}\n")

        # Write kernel count
        with open(self.working_path / "gemm_multi_d_kernel_count.txt", "w") as f:
            f.write(str(len(lines)))

        # Write kernel list
        with open(self.working_path / "gemm_multi_d_kernel_list.txt", "w") as f:
            f.writelines(lines)

        print(f"Listed {len(lines)} kernel configurations")

    # ------------------------------------------------------------------
    # Tile configuration enumeration
    # ------------------------------------------------------------------
    def _resolve_tile_values(self, tile_config, name):
        """Return the candidate values of one tile parameter.

        When the parameter has no explicit ``values`` list, the list is
        generated from its ``min``/``max``/``step`` entries and stored back
        into the configuration under ``values``.

        Raises:
            KeyError: If the parameter is absent from ``tile_config``.
            ValueError: If the range description is incomplete or invalid.
        """
        spec = tile_config.get(name)
        if spec is None:
            raise KeyError(f"tile_config is missing parameter '{name}'")
        if spec.get("values") is None:
            # Generate values in the config if default range is given
            spec["values"] = self._generate_values(
                spec.get("min"), spec.get("max"), spec.get("step")
            )
        return spec["values"]

    def _get_tile_configs(self, fast_mode=False):
        """Get tile configurations for the current datatype and layout.

        Every combination of the nine tile parameters is enumerated and kept
        only if ``_validate_tile_config`` accepts it.

        Args:
            fast_mode: Forwarded to ``_validate_tile_config``.

        Returns:
            A list of dicts keyed by the nine tile parameter names.
        """
        tile_config = self._config_section("tile_config")
        value_lists = [
            self._resolve_tile_values(tile_config, name)
            for name in self._TILE_PARAMS
        ]

        # itertools.product iterates with the last parameter varying
        # fastest, matching a nested loop over the parameters in order.
        return [
            dict(zip(self._TILE_PARAMS, combo))
            for combo in itertools.product(*value_lists)
            if self._validate_tile_config(*combo, fast_mode=fast_mode)
        ]

    def _generate_values(self, min_val, max_val, step):
        """Generate a list of values from min to max (inclusive) with the given step.

        Raises:
            ValueError: If any bound is missing or ``step`` is not positive
                (a non-positive step would never reach ``max_val``).
        """
        if min_val is None or max_val is None or step is None:
            raise ValueError(
                "A value range requires 'min', 'max' and 'step' "
                f"(got min={min_val!r}, max={max_val!r}, step={step!r})"
            )
        if step <= 0:
            raise ValueError(f"Range step must be positive, got {step!r}")

        values = []
        val = min_val
        while val <= max_val:
            values.append(val)
            val += step
        return values

    def _validate_tile_config(
        self,
        tile_m,
        tile_n,
        tile_k,
        warp_m,
        warp_n,
        warp_k,
        warp_tile_m,
        warp_tile_n,
        warp_tile_k,
        pipeline="compv4",  # Default pipeline for validation
        fast_mode=False,  # Add fast mode option
    ):
        """Validate that tile configuration is reasonable.

        In fast mode only positivity and divisibility are checked; otherwise
        the decision is delegated to ``is_tile_config_valid``.

        Returns:
            True if the configuration is accepted, False otherwise.
        """
        if fast_mode:
            # Fast validation for listing - only basic sanity checks
            dims = (
                tile_m,
                tile_n,
                tile_k,
                warp_m,
                warp_n,
                warp_k,
                warp_tile_m,
                warp_tile_n,
                warp_tile_k,
            )
            if any(dim <= 0 for dim in dims):
                return False

            # Basic divisibility check
            return (
                tile_m % (warp_m * warp_tile_m) == 0
                and tile_n % (warp_n * warp_tile_n) == 0
                and tile_k % (warp_k * warp_tile_k) == 0
            )

        # Full validation for generation.
        # A and B use the builder's datatype; C falls back to fp16 for the
        # 8-bit float types.
        c_datatype = "fp16" if self.datatype in ["fp8", "bf8"] else self.datatype

        # Use the comprehensive validation function
        return is_tile_config_valid(
            tile_m,
            tile_n,
            tile_k,
            warp_m,
            warp_n,
            warp_k,
            warp_tile_m,
            warp_tile_n,
            warp_tile_k,
            self.datatype,
            self.datatype,
            c_datatype,
            pipeline,
            self.layout,
            self.gpu_target,
        )

    # ------------------------------------------------------------------
    # Trait enumeration
    # ------------------------------------------------------------------
    def _generate_trait_combinations(self):
        """Generate all supported combinations of traits.

        Returns:
            A list of 7-tuples ``(pipeline, epilogue, scheduler, pad_m,
            pad_n, pad_k, persistent)`` for which
            ``is_trait_combination_valid`` accepts the first three entries.

        Raises:
            KeyError: If the ``trait_config`` section or one of its
                ``values`` lists is missing.
        """
        trait_config = self._config_section("trait_config")

        value_lists = []
        for name in self._TRAIT_PARAMS:
            spec = trait_config.get(name)
            if spec is None or spec.get("values") is None:
                raise KeyError(f"trait_config is missing 'values' for '{name}'")
            value_lists.append(spec.get("values"))

        # Filter out unsupported trait combinations
        combinations = []
        for combo in itertools.product(*value_lists):
            pipeline, epilogue, scheduler = combo[:3]
            if is_trait_combination_valid(pipeline, epilogue, scheduler):
                combinations.append(combo)
            else:
                logging.debug(
                    "Skipping unsupported trait combination: %s-%s-%s",
                    pipeline,
                    epilogue,
                    scheduler,
                )
        return combinations

    # ------------------------------------------------------------------
    # Code generation
    # ------------------------------------------------------------------
    def _generate_kernel_instance(
        self, tile_config, trait_combo, k_block_per_cu, is_header=True
    ):
        """Generate a single kernel instance.

        Args:
            tile_config: Dict with the nine tile parameters.
            trait_combo: 7-tuple ``(pipeline, epilogue, scheduler, pad_m,
                pad_n, pad_k, persistent)``. ``persistent`` only affects the
                kernel name.
            k_block_per_cu: Value emitted as ``kBlockPerCu``.
            is_header: When true, ``#pragma once`` is emitted.

        Returns:
            ``(kernel_name, instance_code)``.

        Raises:
            ValueError: If the pipeline or scheduler name is unknown, which
                would otherwise emit the text ``None`` into the C++ source.
        """
        (
            pipeline,
            epilogue,
            scheduler,
            pad_m,
            pad_n,
            pad_k,
            persistent,
        ) = trait_combo

        # Create kernel name with proper boolean capitalization
        kernel_name = self._make_kernel_name(tile_config, trait_combo)

        # Reject unknown names up front, before any code is produced.
        if pipeline not in self._PIPELINE_IMPLS:
            raise ValueError(
                f"Unknown pipeline '{pipeline}' "
                f"(expected one of: {', '.join(self._PIPELINE_IMPLS)})"
            )
        if scheduler not in self._SCHEDULER_TYPES:
            raise ValueError(
                f"Unknown scheduler '{scheduler}' "
                f"(expected one of: {', '.join(self._SCHEDULER_TYPES)})"
            )
        pipeline_impl = self._PIPELINE_IMPLS[pipeline]
        base_pipeline = self._BASE_PIPELINES[pipeline]
        scheduler_type = self._SCHEDULER_TYPES[scheduler]

        # Determine accumulator type based on datatype
        acc_type = "float"

        # Determine output type
        c_type = self.datatype
        if self.datatype in ["fp8", "bf8"]:
            c_type = "fp16"

        input_type_str = get_dtype_string(self.datatype)
        c_type_str = get_dtype_string(c_type)

        # Determine layouts based on self.layout
        a_layout, b_layout, c_layout, ds_layout = get_abcd_layouts(self.layout)

        # Only the compv4 pipeline enables the double shared-memory buffer flag.
        double_smem_buffer = "true" if pipeline == "compv4" else "false"

        # Generate kernel instance code using the correct API
        pragma_line = "#pragma once\n" if is_header else ""
        instance_code = f"""// Generated kernel instance for {kernel_name}
{pragma_line}
#include <cstdint>
#include <utility>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"

using ADataType = {input_type_str};
using BDataType = {input_type_str};
using AccDataType = {acc_type};
using CDataType = {c_type_str};
using D0DataType = {input_type_str};
using D1DataType = {input_type_str};
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;

using ALayout = {a_layout};
using BLayout = {b_layout};
using CLayout = {c_layout};
using D0Layout = {ds_layout[0]};
using D1Layout = {ds_layout[1]};
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;

using ElementWiseFn = ck_tile::element_wise::{self.elementwise_function};

// Kernel name for display
constexpr const char* KERNEL_NAME = "{kernel_name}";

// Wrapper for simplified launch interface
struct SelectedKernel {{
    // Tile configuration
    static constexpr ck_tile::index_t BlockSize = 256;
    static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]};
    static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]};
    static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]};
    static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]};
    static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]};
    static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]};
    static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]};
    static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]};
    static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]};

    // Traits
    static constexpr bool kPadM = {self._cpp_bool(pad_m)};
    static constexpr bool kPadN = {self._cpp_bool(pad_n)};
    static constexpr bool kPadK = {self._cpp_bool(pad_k)};
    
    static constexpr bool DoubleSmemBuffer = {double_smem_buffer};
    static constexpr bool TransposeC = false;

    // Tile shape
    using TileShape = ck_tile::TileGemmShape<
        ck_tile::sequence<TileM, TileN, TileK>,
        ck_tile::sequence<WarpPerBlock_M, WarpPerBlock_N, WarpPerBlock_K>,
        ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>>;
    
    // Tile partitioner
    using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<TileShape, 8, 4>;
    
    // Traits
    using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
    
    // Pipeline problem
    using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
        ADataType,
        BDataType,
        AccDataType,
        TileShape,
        Traits>;
    
    // Base pipeline for hot loop detection
    using BaseGemmPipeline = {base_pipeline}<GemmPipelineProblem>;

    static float launch(const ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream) {{
        const ck_tile::index_t k_grain = args.k_batch * TileK;
        const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK;
        const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
        const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
        const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
        
        float ave_time{{0}};

        const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{
            constexpr bool has_hot_loop_v = has_hot_loop_.value;
            constexpr auto tail_number_v = tail_number_.value;
            constexpr auto scheduler = {scheduler_type};
            [[maybe_unused]] constexpr auto memory_operation = memory_operation_.value;

            using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
                ADataType,
                BDataType,
                AccDataType,
                TileShape,
                ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
                                                ALayout, BLayout, CLayout, TransposeC>,
                scheduler,
                has_hot_loop_v,
                tail_number_v>;
            
            using GemmPipeline = {pipeline_impl}<UniversalGemmProblem>;
            
            // Epilogue
"""

        # Add epilogue configuration based on type
        if epilogue == "cshuffle":
            instance_code += """            using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
                ADataType,
                BDataType,
                DsDataType,
                AccDataType,
                CDataType,
                DsLayout,
                CLayout,
                ElementWiseFn,
                TilePartitioner::MPerBlock,  // kM_
                TilePartitioner::NPerBlock,  // kN_
                WarpPerBlock_M,              // MWave_
                WarpPerBlock_N,              // NWave_
                WarpTileM,                   // MPerXdl_
                WarpTileN,                   // NPerXdl_
                WarpTileK,                   // KPerXdl_
                TransposeC,                  // isCTransposed_
                memory_operation>;           // MemoryOperation_ 
       
            using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;
"""
        else:  # default epilogue
            instance_code += """            using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
                ADataType,
                BDataType,
                DsDataType,
                AccDataType,
                CDataType,
                DsLayout,
                CLayout,
                ElementWiseFn,
                TilePartitioner::MPerBlock,  // kM_
                TilePartitioner::NPerBlock,  // kN_
                kPadM,
                kPadN,
                WarpTileM,  // kMPerXdl_
                WarpTileN,  // kNPerXdl_
                WarpTileK,  // kKPerXdl_
                TransposeC>;  // isCTransposed_
            
            using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;
"""

        instance_code += f"""
            
            // Kernel type
            using GemmKernelMultiD = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
            
            // Make kernel arguments
            auto kargs = GemmKernelMultiD::MakeKernelArgs(args);
            
            if (!GemmKernelMultiD::IsSupportedArgument(kargs)) {{
                throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
            }}
            
            // Get grid and block sizes
            const dim3 grids = GemmKernelMultiD::GridSize(args.M, args.N, args.k_batch);
            const dim3 blocks = GemmKernelMultiD::BlockSize();
            
            if(stream.log_level_ > 0) {{
                std::cout << "Launching kernel with args: " << GemmKernelMultiD::GetName() << '\\n'
                          << "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
                          << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
                          << std::endl;
            }}
            
            // Launch kernel
            constexpr int kBlockPerCu = {k_block_per_cu};
            ave_time = ck_tile::launch_kernel(
                stream,
                ck_tile::make_kernel<kBlockPerCu>(GemmKernelMultiD{{}}, grids, blocks, 0, kargs));
            
            return ave_time;
        }};

        const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{
            if(args.k_batch == 1) {{
                Run(has_hot_loop_,
                    tail_number_,
                    ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                            ck_tile::memory_operation_enum::set>{{}});
            }} else {{
                Run(has_hot_loop_,
                    tail_number_,
                    ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                            ck_tile::memory_operation_enum::atomic_add>{{}});
            }}
        }};

        BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
        return ave_time;
    }}
}};
"""
        return kernel_name, instance_code

    # ------------------------------------------------------------------
    # Bulk generation
    # ------------------------------------------------------------------
    def run(self, num_workers=None):
        """Run the builder to generate individual kernel files."""
        # Generate individual kernel files
        self.generate_individual(num_workers)

    def generate_individual(self, num_workers=None):
        """Generate individual kernel files for separate compilation with parallel processing.

        Args:
            num_workers: Number of worker processes; defaults to the CPU
                count, capped at 8.
        """
        if num_workers is None:
            num_workers = min(
                multiprocessing.cpu_count(), 8
            )  # Limit to avoid memory issues

        tile_configs = self._get_tile_configs()
        trait_combos = self._generate_trait_combinations()
        k_block_per_cu = self.config.get("k_block_per_cu")
        if k_block_per_cu is None:
            k_block_per_cu = 1

        # Prepare work items for parallel processing. Each item carries the
        # constructor arguments so the worker can rebuild a builder itself.
        work_items = [
            (
                tile_config,
                trait_combo,
                k_block_per_cu,
                self.working_path,
                self.gpu_target,
                self.datatype,
                self.layout,
                self.elementwise_function,
                self.config_json,
            )
            for tile_config in tile_configs
            for trait_combo in trait_combos
        ]

        print(
            f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
        )
        print(f"  Tile configs: {len(tile_configs)}")
        print(f"  Trait combinations: {len(trait_combos)}")
        print(f"  Total kernels: {len(work_items)}")

        # Show first few work items for debugging
        if work_items:
            print("  First work item example:")
            tile_config, trait_combo = work_items[0][:2]
            print(f"    Tile config: {tile_config}")
            print(f"    Trait combo: {trait_combo[:3]}")  # Show first 3 traits

        # Process work items in parallel
        kernel_list = []
        completed = 0

        with concurrent.futures.ProcessPoolExecutor(
            max_workers=num_workers
        ) as executor:
            # Submit all work items
            print(f"  Submitting {len(work_items)} tasks to executor...")
            future_to_item = {
                executor.submit(_generate_single_kernel_individual, item): item
                for item in work_items
            }
            print("  All tasks submitted, waiting for completion...")

            # Collect results with progress reporting
            for future in concurrent.futures.as_completed(future_to_item):
                completed += 1
                if completed % 100 == 0 or completed == len(work_items):
                    print(
                        f"  Progress: {completed}/{len(work_items)} kernels generated"
                    )

                try:
                    result = future.result()
                    if result:
                        kernel_list.append(result)
                except Exception as exc:
                    item = future_to_item[future]
                    print(f"Kernel generation failed for {item}: {exc}")

        # Sort kernel list for consistent ordering
        kernel_list.sort(key=lambda x: x[0])  # Sort by kernel name

        # Generate CMake include file for individual targets
        self._generate_cmake_individual_targets(kernel_list)

        print(
            f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
        )

    def _generate_cmake_individual_targets(self, kernel_list):
        """Generate CMake include file that creates individual targets.

        Writes ``gemm_multi_d_individual_targets.cmake`` into the working
        directory: two comment lines followed by one
        ``create_individual_gemm_multi_d_target(...)`` call per kernel.

        Args:
            kernel_list: Iterable of ``(kernel_name, trait_combo,
                tile_config)`` tuples.
        """
        # Each line is built explicitly so that no Python source indentation
        # ends up in the generated file.
        lines = [
            "# Generated CMake file for individual GEMM Multi D targets\n",
            f"# Datatype: {self.datatype}, Layout: {self.layout}\n",
        ]

        for _kernel_name, trait_combo, tile_config in kernel_list:
            # Format tile config and traits for the CMake function
            tile_str = self._format_tile_str(tile_config)
            trait_str = self._format_trait_str(trait_combo)
            lines.append(
                f'create_individual_gemm_multi_d_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n'
            )

        # Write CMake include file
        with open(
            self.working_path / "gemm_multi_d_individual_targets.cmake", "w"
        ) as f:
            f.writelines(lines)


def _generate_single_kernel_individual(work_item):
    """Worker entry point: emit the header file for one kernel.

    ``work_item`` is the 9-tuple ``(tile_config, trait_combo,
    k_block_per_cu, working_path, gpu_target, datatype, layout,
    elementwise_function, config_json)``.

    Returns ``(kernel_name, trait_combo, tile_config)`` on success. If
    generating or writing the header raises, the error is printed and
    ``None`` is returned instead.
    """
    tile_config, trait_combo, k_block_per_cu, working_path = work_item[:4]
    gpu_target, datatype, layout, elementwise_function, config_json = work_item[4:]

    # Each call constructs its own builder from the plain arguments carried
    # in the work item.
    worker_builder = GemmMultiDKernelBuilder(
        working_path,
        gpu_target,
        datatype,
        layout,
        elementwise_function,
        config_json,
    )

    name_prefix = "gemm_multi_d_"
    try:
        kernel_name, instance_code = worker_builder._generate_kernel_instance(
            tile_config, trait_combo, k_block_per_cu
        )

        # The file name drops the leading "gemm_multi_d_" of the kernel name.
        if kernel_name.startswith(name_prefix):
            file_stem = kernel_name[len(name_prefix):]
        else:
            file_stem = kernel_name

        # Write individual header file
        target = working_path / f"gemm_multi_d_single_{file_stem}.hpp"
        with open(target, "w") as out:
            out.write(instance_code)
    except Exception as e:
        print(f"Error generating individual kernel: {e}")
        return None

    return (kernel_name, trait_combo, tile_config)


def _parse_tile_config_arg(text):
    """Parse ``TMxTNxTK_WMxWNxWK_WTMxWTNxWTK`` into a tile-config dict.

    Raises:
        ValueError: If the string does not consist of three ``_``-separated
            groups of three ``x``-separated integers.
    """
    group_names = (
        ("tile_m", "tile_n", "tile_k"),
        ("warp_m", "warp_n", "warp_k"),
        ("warp_tile_m", "warp_tile_n", "warp_tile_k"),
    )
    groups = text.split("_")
    if len(groups) != len(group_names):
        raise ValueError(
            f"expected 3 '_'-separated groups like 256x128x32_2x2x1_32x32x16, got '{text}'"
        )

    tile_config = {}
    for group, names in zip(groups, group_names):
        dims = group.split("x")
        if len(dims) != len(names):
            raise ValueError(f"expected 3 'x'-separated integers, got '{group}'")
        for name, dim in zip(names, dims):
            tile_config[name] = int(dim)  # int() raises ValueError on non-integers
    return tile_config


def _parse_trait_combo_arg(text):
    """Parse ``pipeline_epilogue_scheduler_padM_padN_padK_persistent``.

    The four trailing flags are true when they spell "true" in any letter
    case, so both ``True`` and ``true`` are accepted.

    Raises:
        ValueError: If the string does not have exactly seven parts.
    """
    parts = text.split("_")
    if len(parts) != 7:
        raise ValueError(f"expected 7 '_'-separated fields, got '{text}'")
    flags = tuple(part.lower() == "true" for part in parts[3:])
    return tuple(parts[:3]) + flags


def main():
    """Command-line entry point.

    Parses the arguments, validates them, and then runs exactly one of the
    three modes, in this order of precedence: ``--list_kernels``,
    ``--gen_single``, ``--gen_all_individual``. All argument validation
    happens before the builder is constructed, so a usage error does not
    create the working directory.

    Raises:
        ValueError: If ``--elementwise_function`` is not a supported name.
    """
    parser = argparse.ArgumentParser(
        description="GEMM Multi D kernel instance builder with parallel support"
    )
    parser.add_argument("--working_path", required=True, help="Working directory path")
    parser.add_argument("--gpu_target", required=True, help="GPU target architecture")
    parser.add_argument(
        "--datatype",
        required=True,
        choices=["fp16"],
        help="Data type",
    )
    parser.add_argument(
        "--layout",
        required=True,
        choices=["rcrr", "rrrr", "ccrr", "crrr"],
        help="Matrix layout",
    )
    parser.add_argument(
        "--elementwise_function",
        required=True,
        help="Specify what element wise function for D, e.g. mul, add, passthrough",
    )
    parser.add_argument("--config_json", help="Configuration JSON file")
    parser.add_argument(
        "--num_workers", type=int, help="Number of parallel workers (default: auto)"
    )
    parser.add_argument(
        "--gen_all_individual",
        action="store_true",
        help="Generate individual kernel files",
    )
    parser.add_argument(
        "--gen_single", action="store_true", help="Generate a single kernel file"
    )
    parser.add_argument("--kernel_name", help="Kernel name for single generation")
    parser.add_argument(
        "--tile_config", help="Tile configuration string for single generation"
    )
    parser.add_argument(
        "--trait_combo", help="Trait combination string for single generation"
    )
    parser.add_argument(
        "--list_kernels",
        action="store_true",
        help="List kernel configurations without generating files",
    )

    args = parser.parse_args()

    # Explicit checks instead of ``assert`` so they still run under ``-O``.
    # ``choices`` already enforces these; they guard against the choices
    # being widened without updating the rest of the script.
    if args.datatype not in ["fp16"]:
        parser.error(
            f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16])"
        )

    layout_parts = args.layout.lower()
    if len(layout_parts) != 4:
        parser.error(
            f"Invalid layout string: {args.layout} (must be 4 characters like 'rcrr' where r stands for row major and c stands for column major)"
        )
    if layout_parts[0] not in ["r", "c"] or layout_parts[1] not in ["r", "c"]:
        parser.error(
            f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)"
        )
    if layout_parts[2] != "r" or layout_parts[3] != "r":
        parser.error(
            f"Invalid matrix_c or d dimension in layout: {layout_parts[2]} and {layout_parts[3]} (must be 'r' only as currently we are supporting only row major)"
        )

    # Elementwise function name validation and mapping to the C++ functor name
    function_names = {
        "mul": "MultiDMultiply",
        "add": "MultiDAdd",
        "passthrough": "PassThrough",
    }
    elementwise_function = args.elementwise_function.lower()
    if elementwise_function not in function_names:
        raise ValueError(
            f"Invalid elementwise function: {elementwise_function}. "
            f"Valid options are: {', '.join(function_names)}"
        )
    args.elementwise_function = function_names[elementwise_function]

    # Resolve the mode and validate its arguments before touching the disk.
    tile_config = None
    trait_combo = None
    if args.list_kernels:
        mode = "list"
    elif args.gen_single:
        mode = "single"
        if not args.kernel_name or not args.tile_config or not args.trait_combo:
            parser.error(
                "--gen_single requires --kernel_name, --tile_config, and --trait_combo"
            )
        try:
            tile_config = _parse_tile_config_arg(args.tile_config)
        except ValueError as exc:
            parser.error(f"Invalid --tile_config: {exc}")
        try:
            trait_combo = _parse_trait_combo_arg(args.trait_combo)
        except ValueError as exc:
            parser.error(f"Invalid --trait_combo: {exc}")
    elif args.gen_all_individual:
        mode = "all"
    else:
        parser.error(
            "Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
        )

    # Create builder
    builder = GemmMultiDKernelBuilder(
        args.working_path,
        args.gpu_target,
        args.datatype,
        args.layout,
        args.elementwise_function,
        args.config_json,
    )

    if mode == "list":
        builder.write_kernel_list()
    elif mode == "single":
        # The builder may have no loaded configuration (no --config_json);
        # fall back to the default of one block per CU in that case.
        config = getattr(builder, "config", None) or {}
        k_block_per_cu = config.get("k_block_per_cu")
        if k_block_per_cu is None:
            k_block_per_cu = 1

        # Generate the kernel. The output name is derived from the parsed
        # tile config and traits; --kernel_name itself is not used for it.
        kernel_name, instance_code = builder._generate_kernel_instance(
            tile_config, trait_combo, k_block_per_cu
        )

        # Write the file, dropping the "gemm_multi_d_" prefix from the name
        name_prefix = "gemm_multi_d_"
        simplified_name = kernel_name
        if simplified_name.startswith(name_prefix):
            simplified_name = simplified_name[len(name_prefix):]

        header_file = (
            builder.working_path / f"gemm_multi_d_single_{simplified_name}.hpp"
        )
        with open(header_file, "w") as f:
            f.write(instance_code)

        print(f"Generated {header_file}")
    else:
        # Generate all individual kernel files
        builder.run(args.num_workers)


if __name__ == "__main__":
    main()
